{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_UaXOSRjDUF9"
   },
   "source": [
    "# Experiment\n",
    "Finally, we can assemble building blocks that we have came across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2QRbCJvDHNAd"
   },
   "source": [
    "## Experiment\n",
    "To conduct this experiment, we need the following building blocks.\n",
    "\n",
    "\n",
    "*   Two vectorized environments, one for training and one for evaluation\n",
    "*   A PPO agent\n",
    "*   A replay buffer to store transition data\n",
    "*   Two collectors to manage the data collecting process, one for training and one for evaluation\n",
    "*   A trainer to manage the training loop\n",
    "\n",
    "<div align=center>\n",
    "<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\">\n",
    "\n",
    "</div>\n",
    "\n",
    "Let us do this step by step."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-Hh4E6i0Hj0I"
   },
   "source": [
    "## Preparation\n",
    "Firstly, install Tianshou if you haven't installed it before."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7E4EhiBeHxD5"
   },
   "source": [
    "Import libraries we might need later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "id": "ao9gWJDiHgG-",
    "tags": [
     "hide-cell",
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "import gymnasium as gym\n",
    "import torch\n",
    "\n",
    "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n",
    "from tianshou.env import DummyVectorEnv\n",
    "from tianshou.policy import PPOPolicy\n",
    "from tianshou.trainer import OnpolicyTrainer\n",
    "from tianshou.utils.net.common import ActorCritic, Net\n",
    "from tianshou.utils.net.discrete import Actor, Critic\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QnRg5y7THRYw"
   },
   "source": [
    "## Environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YZERKCGtH8W1"
   },
   "source": [
    "We create two vectorized environments both for training and testing. Since the execution time of CartPole is extremely short, there is no need to use multi-process wrappers and we simply use DummyVectorEnv."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Mpuj5PFnDKVS"
   },
   "outputs": [],
   "source": [
    "env = gym.make(\"CartPole-v1\")\n",
    "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n",
    "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BJtt_Ya8DTAh"
   },
   "source": [
    "## Policy\n",
    "Next we need to initialize our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic in PPO first.\n",
    "\n",
    "The actor is a neural network that shares the same network head with the critic. Both networks' input is the environment observation. The output of the actor is the action and the output of the critic is a single value, representing the value of the current policy.\n",
    "\n",
    "Luckily, Tianshou already provides basic network modules that we can use in this experiment."
   ]
  },
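  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of a shared network concrete, here is a minimal sketch in plain PyTorch. It is not Tianshou's implementation: the class `SharedActorCritic` and its attribute names are hypothetical, and the layer sizes simply mirror the two hidden layers of 64 units used in this tutorial.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class SharedActorCritic(nn.Module):\n",
    "    # Hypothetical stand-in for a shared preprocessing net with two output layers.\n",
    "\n",
    "    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 64) -> None:\n",
    "        super().__init__()\n",
    "        # Shared trunk: both outputs are computed from these features.\n",
    "        self.trunk = nn.Sequential(\n",
    "            nn.Linear(obs_dim, hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden, hidden),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        self.actor_out = nn.Linear(hidden, n_actions)  # one logit per action\n",
    "        self.critic_out = nn.Linear(hidden, 1)  # scalar state value\n",
    "\n",
    "    def forward(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n",
    "        features = self.trunk(obs)\n",
    "        probs = torch.softmax(self.actor_out(features), dim=-1)\n",
    "        value = self.critic_out(features).squeeze(-1)\n",
    "        return probs, value\n",
    "\n",
    "\n",
    "model = SharedActorCritic(obs_dim=4, n_actions=2)  # CartPole: 4 observations, 2 actions\n",
    "probs, value = model(torch.zeros(3, 4))  # a batch of 3 dummy observations\n",
    "print(probs.shape, value.shape)\n",
    "```\n",
    "\n",
    "In a layout like this, gradients from both the policy loss and the value loss flow into the shared trunk, which is why a single optimizer over all parameters is sufficient."
   ]
  },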
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_Vy8uPWXP4m_"
   },
   "outputs": [],
   "source": [
    "# net is the shared head of the actor and the critic\n",
    "assert env.observation_space.shape is not None  # for mypy\n",
    "assert isinstance(env.action_space, gym.spaces.Discrete)  # for mypy\n",
    "net = Net(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n",
    "actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n",
    "critic = Critic(preprocess_net=net, device=device).to(device)\n",
    "actor_critic = ActorCritic(actor=actor, critic=critic)\n",
    "\n",
    "# optimizer of the actor and the critic\n",
    "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Lh2-hwE5Dn9I"
   },
   "source": [
    "Once we have defined the actor, the critic and the optimizer, we can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OiJ2GkT0Qnbr"
   },
   "outputs": [],
   "source": [
    "dist = torch.distributions.Categorical\n",
    "policy: PPOPolicy = PPOPolicy(\n",
    "    actor=actor,\n",
    "    critic=critic,\n",
    "    optim=optim,\n",
    "    dist_fn=dist,\n",
    "    action_space=env.action_space,\n",
    "    deterministic_eval=True,\n",
    "    action_scaling=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "okxfj6IEQ-r8"
   },
   "source": [
    "`deterministic_eval=True` means that we want to sample actions during training but we would like to always use the best action in evaluation. No randomness included."
   ]
  },
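  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between the two modes can be illustrated with a categorical distribution directly. This is a conceptual sketch, not the code path Tianshou runs internally, and the probabilities are made up for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "probs = torch.tensor([0.7, 0.3])  # made-up action probabilities\n",
    "dist = torch.distributions.Categorical(probs=probs)\n",
    "\n",
    "sampled_action = dist.sample()  # stochastic: action 0 about 70% of the time\n",
    "greedy_action = dist.probs.argmax(dim=-1)  # deterministic: the most probable action\n",
    "print(sampled_action.item(), greedy_action.item())\n",
    "```\n",
    "\n",
    "Sampling keeps the agent exploring while it learns; taking the most probable action gives a repeatable measure of how good the learned policy is."
   ]
  },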
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "n5XAAbuBZarO"
   },
   "source": [
    "## Collector\n",
    "We can set up the collectors now. Train collector is used to collect and store training data, so an additional replay buffer has to be passed in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ezwz0qerZhQM"
   },
   "outputs": [],
   "source": [
    "train_collector = Collector[CollectStats](\n",
    "    policy=policy,\n",
    "    env=train_envs,\n",
    "    buffer=VectorReplayBuffer(20000, len(train_envs)),\n",
    ")\n",
    "test_collector = Collector[CollectStats](policy=policy, env=test_envs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZaoPxOd2hm0b"
   },
   "source": [
    "We use `VectorReplayBuffer` here because it's more efficient to collaborate with vectorized environments, you can simply consider `VectorReplayBuffer` as a a list of ordinary replay buffers."
   ]
  },
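  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The list-of-buffers picture can be sketched in plain Python. This is only a mental model, not Tianshou's data structure: the names `sub_buffers` and `add` are hypothetical, and the even split of the total size across environments is an assumption of this sketch.\n",
    "\n",
    "```python\n",
    "from collections import deque\n",
    "\n",
    "total_size, num_envs = 20000, 20  # the values passed to VectorReplayBuffer above\n",
    "per_env_size = total_size // num_envs  # assumed capacity of each sub-buffer\n",
    "\n",
    "# Hypothetical stand-in: one bounded FIFO queue per environment.\n",
    "sub_buffers = [deque(maxlen=per_env_size) for _ in range(num_envs)]\n",
    "\n",
    "\n",
    "def add(env_id, transition):\n",
    "    # A transition from environment env_id only goes into sub-buffer env_id,\n",
    "    # so the steps of one episode stay together in their own sub-buffer.\n",
    "    sub_buffers[env_id].append(transition)\n",
    "\n",
    "\n",
    "add(0, {'obs': [0.0, 0.0, 0.0, 0.0], 'act': 1, 'rew': 1.0})\n",
    "add(3, {'obs': [0.1, 0.0, 0.0, 0.0], 'act': 0, 'rew': 1.0})\n",
    "print(per_env_size, [len(b) for b in sub_buffers[:4]])\n",
    "```\n",
    "\n",
    "Keeping each environment's transitions separate means that interleaved steps from parallel environments never get mixed into one another's episodes."
   ]
  },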
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qBoE9pLUiC-8"
   },
   "source": [
    "## Trainer\n",
    "Finally, we can use the trainer to help us set up the training loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "editable": true,
    "id": "i45EDnpxQ8gu",
    "outputId": "b1666b88-0bfa-4340-868e-58611872d988",
    "tags": [
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "result = OnpolicyTrainer(\n",
    "    policy=policy,\n",
    "    train_collector=train_collector,\n",
    "    test_collector=test_collector,\n",
    "    max_epoch=10,\n",
    "    step_per_epoch=50000,\n",
    "    repeat_per_collect=10,\n",
    "    episode_per_test=10,\n",
    "    batch_size=256,\n",
    "    step_per_collect=2000,\n",
    "    stop_fn=lambda mean_reward: mean_reward >= 195,\n",
    ").run()"
   ]
  },
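  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The trainer arguments determine how much data is collected and how often it is reused. The arithmetic below is a rough sketch that assumes every collect call gathers exactly `step_per_collect` transitions, split evenly across the 20 training environments. The derived variable names are ours, for illustration only.\n",
    "\n",
    "```python\n",
    "max_epoch = 10\n",
    "step_per_epoch = 50000\n",
    "step_per_collect = 2000\n",
    "repeat_per_collect = 10\n",
    "num_train_envs = 20\n",
    "\n",
    "# Collect/update rounds needed to fill one epoch\n",
    "collects_per_epoch = step_per_epoch // step_per_collect\n",
    "# Transitions contributed by each environment in one round\n",
    "steps_per_env_per_collect = step_per_collect // num_train_envs\n",
    "# Upper bound on environment steps if stop_fn never triggers\n",
    "max_env_steps = max_epoch * step_per_epoch\n",
    "# Passes over freshly collected data within one epoch\n",
    "passes_per_epoch = collects_per_epoch * repeat_per_collect\n",
    "\n",
    "print(collects_per_epoch, steps_per_env_per_collect, max_env_steps, passes_per_epoch)\n",
    "```\n",
    "\n",
    "Training can end well before the upper bound, because `stop_fn` stops the run as soon as the mean test reward reaches 195."
   ]
  },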
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ckgINHE2iTFR"
   },
   "source": [
    "## Results\n",
    "Print the training result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tJCPgmiyiaaX",
    "outputId": "40123ae3-3365-4782-9563-46c43812f10f",
    "tags": []
   },
   "outputs": [],
   "source": [
    "result.pprint_asdict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "A-MJ9avMibxN"
   },
   "source": [
    "We can also test our trained agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "mnMANFcciiAQ",
    "outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21"
   },
   "outputs": [],
   "source": [
    "# Let's watch its performance!\n",
    "policy.eval()\n",
    "result = test_collector.collect(n_episode=1, render=False)\n",
    "print(f\"Final episode reward: {result.returns.mean()}, length: {result.lens.mean()}\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
